import pandas as pd
import itertools
import random
from scipy.stats import bernoulli
import numpy as np
import copy
import traceback
import joblib


import os
import sys
sys.path.append("..")
sys.path.append("GerryFair")
import gerryfair


from util_functions_cbs_benchmarking import *
from multiaccuracy_funcs import * 

from conditional_bias_scan import ConditionalBiasScan
from cbs_preprocessor import CBSPreProcessor
from cbs_logger import CBSLogger
from yaml_funcs import YamlFunctions
from dataset_specific_funcs import DatasetSpecificFuncs
from sklearn import linear_model


import matplotlib.pyplot as plt
import time

from  scipy.stats import pearsonr

import multiprocessing as mp
import gc



# Number of pickled run descriptions ("hpc_info/_<array_id>_<idx>.pkl") that
# each SLURM array task processes.
N_INSTANCES_PER_TASK = 1504

# Outcome column and feature columns used by the prediction-separation
# baselines (GerryFair and multi-accuracy). These names are hard-coded for the
# dataset these benchmarks were written against.
OUTCOME_COL = "ReoffendedWithinTwoYears"
PREDICTION_SEPARATION_FEATURES = (
    "Under 25",
    "Prior Offenses",
    "Race",
    "ChargeDegree",
    "Sex",
    OUTCOME_COL,
)


def result_file_name(timestr, run_info, experiment_name, scan_type):
    """Return the CSV path that holds the results of a single benchmark run.

    The main loop also uses this path as a completion marker: a run whose
    file already exists is skipped.

    Args:
        timestr: Identifier of the benchmark batch (directory component).
        run_info: Run description; its ``run_number``, ``varying_parameter``,
            ``sigma`` and ``mu`` entries are embedded in the file name.
        experiment_name: Experiment label embedded in the file name.
        scan_type: Scan type label embedded in the file name.

    Returns:
        The relative path as a string.
    """
    # `!s` keeps the exact str() rendering the original concatenation used.
    return (
        f"Benchmark_results/_{timestr!s}/bias_results"
        f"/_run_num_{run_info['run_number']!s}"
        f"_{experiment_name!s}"
        f"_{run_info['varying_parameter']!s}"
        f"_{scan_type!s}"
        f"_sigmapred_{run_info['sigma']!s}"
        f"_mu_{run_info['mu']!s}.csv"
    )


def run_gerryfair_audit(df_p, recommendations, group_ind, s_bias, fairness_def,
                        one_protected, dataset=None, attributes=None,
                        centered=True):
    """Run one GerryFair audit and score the violated group against ``s_bias``.

    Args:
        df_p: Feature frame (still containing ``OUTCOME_COL``) passed to
            GerryFair's dataset cleaning and to the scoring helpers.
        recommendations: Predictions handed to ``Auditor.audit``.
        group_ind: Column name forwarded to the one-protected attribute
            builder and to the scoring helpers.
        s_bias: The run's selected bias subset, forwarded to the scoring
            helpers.
        fairness_def: GerryFair fairness notion, ``"FP"`` or ``"FN"``.
        one_protected: If True build attributes with
            ``create_attributes_data_one_protected``, otherwise with
            ``create_attributes_data_all_protected``.
        dataset, attributes, centered: Forwarded unchanged to
            ``gerryfair.clean.clean_dataset``.

    Returns:
        dict with the ``gerryfair_{all,one}_features_{accuracy,precision,
        recall,fairness_violation_score}_<fairness_def>`` entries, in that
        order (the order determines the CSV column order).
    """
    if one_protected:
        attributes_df = create_attributes_data_one_protected(df_p, OUTCOME_COL, group_ind)
        prefix = "gerryfair_one_features"
    else:
        attributes_df = create_attributes_data_all_protected(df_p, OUTCOME_COL)
        prefix = "gerryfair_all_features"

    _X, X_prime, y = gerryfair.clean.clean_dataset(
        dataset, attributes, centered, data=df_p, attributes_df=attributes_df)
    auditor = gerryfair.model.Auditor(X_prime, y, fairness_def)
    violated_group, fairness_violation, _group = auditor.audit(
        recommendations, under_estimation=True)

    return {
        f"{prefix}_accuracy_{fairness_def}":
            compute_accuracy_gerryfair(df_p, violated_group, s_bias, group_ind),
        f"{prefix}_precision_{fairness_def}":
            compute_precision_gerryfair(df_p, violated_group, s_bias, group_ind),
        f"{prefix}_recall_{fairness_def}":
            compute_recall_gerryfair(df_p, violated_group, s_bias, group_ind),
        f"{prefix}_fairness_violation_score_{fairness_def}": fairness_violation,
    }


def run_multiaccuracy_variant(runner, df_p, log_odds, outcomes, s_bias,
                              group_ind, suffix=""):
    """Run one multi-accuracy variant and score its best group against ``s_bias``.

    Args:
        runner: One of the ``run_multiaccuracy*`` functions; it is called as
            ``runner(df_p, log_odds, outcomes, "FP")``.
        df_p: Feature frame without the outcome column.
        log_odds: Predicted log-odds passed to ``runner``.
        outcomes: Observed outcomes passed to ``runner``.
        s_bias: The run's selected bias subset, forwarded to the scoring
            helpers.
        group_ind: Column name forwarded to the scoring helpers.
        suffix: Appended to every returned key (e.g. ``"_org"``).

    Returns:
        dict with ``multiaccuracy_{accuracy,precision,recall,score}<suffix>``;
        the score is the ``corr`` value returned by ``find_highest_corr``.
    """
    res_data = runner(df_p, log_odds, outcomes, "FP")
    corr, ma_id = find_highest_corr(res_data)
    return {
        f"multiaccuracy_accuracy{suffix}":
            compute_accuracy_gerryfair(df_p, list(ma_id), s_bias, group_ind),
        f"multiaccuracy_precision{suffix}":
            compute_precision_gerryfair(df_p, list(ma_id), s_bias, group_ind),
        f"multiaccuracy_recall{suffix}":
            compute_recall_gerryfair(df_p, list(ma_id), s_bias, group_ind),
        f"multiaccuracy_score{suffix}": corr,
    }


if __name__ == "__main__":
    # Benchmark driver for one SLURM array task. For each pickled run
    # description "hpc_info/_<array_id>_<idx>.pkl" it:
    #   * runs a ConditionalBiasScan,
    #   * logs statistics for the found subset and for the selected bias
    #     subset,
    #   * for prediction-separation scans, also runs the GerryFair and
    #     multi-accuracy baselines,
    #   * writes one CSV row per run.
    # Runs whose CSV already exists are skipped, so the job can be resumed.
    array_id = int(os.environ["SLURM_ARRAY_TASK_ID"])
    print(f"The recieved array id is: {array_id}")

    gc.collect()

    yaml_configs_path = "../fsscan_yamls/fsscan_configs-CBS_benchmark.yaml"
    yaml_funcs = YamlFunctions(yaml_configs_path)
    fsscan_configs = yaml_funcs.run()
    fsscan_params = fsscan_configs["fsscan_params"]
    cbs_logger = CBSLogger(fsscan_configs["results_folder"])

    # Positional arguments forwarded to gerryfair.clean.clean_dataset by the
    # benchmark baselines; the frames themselves go via data=/attributes_df=.
    dataset = None
    attributes = None
    centered = True

    for idx in range(N_INSTANCES_PER_TASK):
        print(idx)

        # NOTE(review): joblib.load unpickles arbitrary objects; only load
        # files you trust (presumably written by the job-preparation step).
        all_info = joblib.load("hpc_info/_" + str(array_id) + "_" + str(idx) + ".pkl")

        timestr = all_info["timestr"]
        run_info = all_info["run_info"]

        scan_params = run_info["scan_params"]
        scan = scan_params["scan_info"]
        dataset_yaml = scan_params["dataset_yaml"]
        data = scan_params["data"]
        p_bin_var = scan_params["p_bin_var"]
        tilde_probability_var = scan_params["tilde_probability_var"]
        df_t = scan_params["df_t"]
        experiment_name = scan_params["experiment_name"]

        s_bias = run_info["selected_bias_subset"]
        group_ind = run_info["group_ind"]
        key, _key_value = run_info["protected_class"]

        # Computed once, before run_info is modified, and reused for writing
        # so the skip test and the written file always refer to the same path.
        file_name = result_file_name(timestr, run_info, experiment_name, scan["scan_type"])
        if os.path.isfile(file_name):
            # Already computed by an earlier job; skip.
            continue

        cbs = ConditionalBiasScan(
            scan["protected_class"], scan["protected_value"], scan["combo"],
            scan["event"], scan["conditional_variable"], fsscan_params,
            scan["direction"], scan["feature_list"], scan["scan_type"],
            scan["scan_feature_list"], scan["threshold_probability"],
            scan["threshold_cutoff"], printing_on=False,
        )
        results = cbs.run(dataset_yaml, data, p_bin_var, tilde_probability_var)

        # Positional arguments shared by both write_results calls below.
        shared_stats = (
            results["treatment"],
            results["treatment_events"],
            results["treatment_p_hat"],
            results["controls"],
            results["control_events"],
            results["control_conditional_var"],
            results["treatment_conditional_var"],
            results["dataset_yaml"],
        )
        scan_config = (
            scan["protected_class"],
            scan["protected_value"],
            scan["combo"],
            scan["event"],
            scan["conditional_variable"],
            fsscan_params,
            scan["direction"],
            scan["feature_list"],
            scan["scan_type"],
            scan["scan_feature_list"],
        )

        # Statistics for the subset found by the scan.
        stats_dict = cbs_logger.write_results(
            results["best_subset"], results["best_score"], results["best_param"],
            *shared_stats, *scan_config, "",
            add_scores=True, include_conditional_var_base_rates=True,
        )
        # Same statistics for the selected bias subset, with sentinel score
        # (-1000000000) and param (inf) values and a "_for_s_bias" label.
        other_dict = cbs_logger.write_results(
            s_bias, -1000000000, np.inf,
            *shared_stats, *scan_config, "_for_s_bias",
            add_scores=True, include_conditional_var_base_rates=True,
        )

        # The bulky scan inputs are not written to the CSV.
        del run_info["scan_params"]

        print(r"coefficients used for variable of logistic regression used to produce \hat p: ")
        print(results["p_hat_coefficient_mapping"])
        print("best subset found : " + str(results["best_subset"]))
        print("best score : " + str(results["best_score"]))
        print("param for best scoring subset : " + str(results["best_param"]))

        s_found_subset = results["best_subset"]
        print(s_found_subset)
        print("accuracy:")
        cbs_accuracy = compute_accuracy(df_t, s_bias, s_found_subset, group_ind)
        print(cbs_accuracy)

        # Insertion order here determines the CSV column order.
        run_info.update({
            "combo": scan["combo"],
            "event": scan["event"],
            "experiment_name": experiment_name,
            "conditional_variable": scan["conditional_variable"],
            "fsscan_params": fsscan_params,
            "direction": scan["direction"],
            "feature_list": scan["feature_list"],
            "scan_type": scan["scan_type"],
            "scan_feature_list": scan["scan_feature_list"],
            "threshold_probability": scan["threshold_probability"],
            "threshold_cutoff": scan["threshold_cutoff"],
            "best_subset": results["best_subset"],
            "best_score": results["best_score"],
            "best_param": results["best_param"],
            "cbs_accuracy": cbs_accuracy,
            "cbs_precision": compute_precision(df_t, s_bias, s_found_subset, group_ind),
            "cbs_recall": compute_recall(df_t, s_bias, s_found_subset, group_ind),
            "cbs_param": results["best_param"],
            "cbs_score": results["best_score"],
            "p_hat_coefficient_mapping": results["p_hat_coefficient_mapping"],
        })

        if "prediction_separation" in scan["scan_type"]:
            recommendations = df_t["predicted_probs"]
            df_p = df_t[[*PREDICTION_SEPARATION_FEATURES, group_ind]]
            # The protected attribute itself is not given to the baselines.
            del df_p[key]

            # GerryFair audits: {FP, FN} x {all features, one feature} protected.
            for fairness_def in ("FP", "FN"):
                for one_protected in (False, True):
                    run_info.update(run_gerryfair_audit(
                        df_p, recommendations, group_ind, s_bias, fairness_def,
                        one_protected, dataset=dataset, attributes=attributes,
                        centered=centered,
                    ))

            print("Running multi-accuracy")

            outcomes = df_p[OUTCOME_COL]
            del df_p[OUTCOME_COL]

            for runner, suffix in (
                (run_multiaccuracy, ""),
                (run_multiaccuracy_org, "_org"),
                (run_multiaccuracy_exp, "_exp"),
            ):
                run_info.update(run_multiaccuracy_variant(
                    runner, df_p, df_t["predicted_log_odds"], outcomes,
                    s_bias, group_ind, suffix,
                ))

        run_info = {**run_info, **stats_dict, **other_dict}

        # Write to the path checked above (not one rebuilt from the merged
        # dict, whose keys stats_dict/other_dict could overwrite), creating
        # the output directory if this is the first result for the batch.
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        pd.DataFrame([run_info]).to_csv(file_name)
                


   